Bayesian inference in JAX


Marie-Hélène Burle

February 25, 2025


On probabilities

Two interpretations of probabilities


Frequentist

Frequentist approach to probabilities: defines the probability of an event as its long-run frequency

For events that cannot be repeated (no long run), you imagine alternative realities and consider the frequency of occurrence across all those realities

The frequentist approach is computationally simple and fast and returns summary statistics

Bayesian

Bayesian approach: assigns probabilities to our beliefs about an event

Based on Bayes’ theorem of conditional probabilities, which allows us to calculate the probability of a cause given its effect:

\[ P(A \vert X) = \frac{P(X \vert A) P(A)}{P(X)} \]

where:

  • \(P(A)\) is the prior probability of \(A\)—our belief about event \(A\)
  • \(P(X)\) is the marginal probability of event \(X\) (some observed data)
  • \(P(X \vert A)\) is the likelihood or conditional probability of observing \(X\) given \(A\)
  • \(P(A \vert X)\) is the posterior probability—our updated belief about \(A\) given the data
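As a minimal worked example of the formula above, the sketch below computes a posterior from a prior and two likelihoods. All numbers are hypothetical and chosen only for illustration; the marginal \(P(X)\) is obtained with the law of total probability.

```python
# Hypothetical numbers, chosen only for illustration
prior_a = 0.01            # P(A): prior probability of the cause A
lik_x_given_a = 0.95      # P(X | A): probability of the data X if A is true
lik_x_given_not_a = 0.05  # P(X | not A): probability of the data X if A is false

# Marginal probability of the data, P(X), by the law of total probability
marginal_x = lik_x_given_a * prior_a + lik_x_given_not_a * (1 - prior_a)

# Bayes' theorem: P(A | X) = P(X | A) P(A) / P(X)
posterior_a = lik_x_given_a * prior_a / marginal_x

print(f"P(X)     = {marginal_x:.4f}")
print(f"P(A | X) = {posterior_a:.4f}")
```

With these numbers, the posterior stays well below the likelihood \(P(X \vert A)\) because the prior \(P(A)\) is small.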

Which approach to choose?

Bayesian statistics:

  • is closer to the way we intuitively think about the world (easier to interpret)
  • allows for the incorporation of prior information and diverse data
  • is more informative as it provides a measure of uncertainty (returns probabilities)
  • is extremely valuable when there is little data (in that case, frequentist inference is unstable: estimates have a large variance and wide confidence intervals)

But beyond extremely simple examples, Bayesian inference is mathematically arduous

It is also much more computationally heavy and only became possible to apply widely with the advent of powerful computers and new algorithms such as Markov chain Monte Carlo (MCMC)

Bayesian computing

Algorithms

Markov chain Monte Carlo (MCMC) is a class of sampling algorithms which explore the posterior distribution

Different algorithms move in different ways across the N-dimensional space of the parameters, accepting or rejecting each new position based on how well it agrees with the prior distribution and the data

The sequence of accepted positions constitutes the traces
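To make the accept/reject idea concrete, here is a minimal sketch of random-walk Metropolis, one of the simplest MCMC algorithms (not the HMC or NUTS samplers discussed later). The model (Normal prior, Normal likelihood), the data, and all names and parameter values are assumptions made for illustration.

```python
import math
import random

def log_posterior(theta, data, prior_sd=10.0, noise_sd=1.0):
    """Unnormalized log-posterior: Normal(0, prior_sd) prior + Normal likelihood."""
    log_prior = -0.5 * (theta / prior_sd) ** 2
    log_lik = sum(-0.5 * ((x - theta) / noise_sd) ** 2 for x in data)
    return log_prior + log_lik

def metropolis(data, n_steps=5000, step_size=0.5, seed=0):
    """Random-walk Metropolis sampler for a single parameter."""
    rng = random.Random(seed)
    theta = 0.0
    current_lp = log_posterior(theta, data)
    trace = []
    for _ in range(n_steps):
        # Propose a new position near the current one
        proposal = theta + rng.gauss(0.0, step_size)
        proposal_lp = log_posterior(proposal, data)
        # Accept with probability min(1, posterior(proposal) / posterior(current))
        accept_prob = math.exp(min(0.0, proposal_lp - current_lp))
        if rng.random() < accept_prob:
            theta, current_lp = proposal, proposal_lp
        # Record the current position (unchanged if the proposal was rejected)
        trace.append(theta)
    return trace

data = [1.8, 2.1, 2.4, 1.9, 2.3]  # hypothetical observations
trace = metropolis(data)
kept = trace[1000:]               # discard the first steps as warm-up
posterior_mean = sum(kept) / len(kept)
print(f"Estimated posterior mean: {posterior_mean:.2f}")
```

The estimate should land close to the sample mean of the data, since the prior here is very wide.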

Probabilistic Programming Language

Probabilistic programming languages (PPLs), explained simply in this (a bit outdated) blog post, are computer languages specialized in creating probabilistic models and making inferences

Model components are first-class primitives

They can be based on a general programming language (e.g. Python, Julia) or be domain-specific

First Bayesian PPLs

Relied on Gibbs sampling:

BUGS = Bayesian inference Using Gibbs Sampling
JAGS = Just Another Gibbs Sampler

Stan

Stan (see also website and paper) is a domain-specific language

Stan scripts can be executed from R, Python, or the shell via RStan, PyStan, etc.

Also used as the backend for the R package brms, which doesn’t require learning Stan but only works for simple models

Relies on the No-U-Turn sampler (NUTS), a variant of Hamiltonian Monte Carlo (HMC) (see also HMC paper)

HMC and variants require burdensome calculations of derivatives. Stan solved that by creating its own reverse-mode automatic differentiation engine

Superior to Gibbs sampler ➔ made Stan a very popular PPL for years

PPLs based on deep learning frameworks

Since HMC and NUTS require autodiff, many Python PPLs have emerged in recent years, following the explosion of deep learning

Examples:

Enter JAX


Had JAX existed when we started coding Stan in 2011, we would’ve used that rather than rolling our own autodiff system.

Bob Carpenter, one of Stan’s creators, in a recent blog post

What is JAX?

JAX is a library for Python that:

  • makes use of the extremely performant XLA compiler
  • runs on accelerators (GPUs/TPUs)
  • provides automatic differentiation
  • uses just-in-time compilation
  • allows batching and parallelization

⇒ perfect tool for Bayesian statistics
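The features listed above can be sketched in a few lines. This is a minimal illustration, not part of any PPL: the function `log_density` and its inputs are hypothetical, while `jax.grad`, `jax.vmap`, and `jax.jit` are the standard JAX transformations.

```python
import jax
import jax.numpy as jnp

def log_density(mu, x):
    """Log-density (up to a constant) of x under a Normal(mu, 1)."""
    return -0.5 * (x - mu) ** 2

# Automatic differentiation: derivative with respect to the first argument (mu)
dlogp_dmu = jax.grad(log_density)

# Batching: apply the derivative to a whole array of observations at once
batched = jax.vmap(dlogp_dmu, in_axes=(None, 0))

# Just-in-time compilation: compile the batched function with XLA
fast_batched = jax.jit(batched)

xs = jnp.array([1.0, 2.0, 3.0])  # hypothetical observations
result = fast_batched(0.5, xs)   # analytically, the derivative is x - mu
print(result)
```

The transformations compose: here a gradient is vectorized, then compiled, without rewriting the original function.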


See our introductory JAX course and webinar for more details

[Diagram of the JAX pipeline: pure Python functions → tracing → jaxprs (JAX expressions), an intermediate representation (IR) → transformations (vmap, pmap, grad: vectorization, parallelization, differentiation) and just-in-time (JIT) compilation (jit) → high-level optimized (HLO) program → Accelerated Linear Algebra (XLA) → CPU, GPU, TPU]

JAX idiosyncrasies

JAX is a sublanguage of Python requiring pure functions instead of Python’s object-oriented style

It has other quirks

The only one you really need to understand for use in PPLs is pseudorandom number generation

PRNG keys

Traditional pseudorandom number generators rely on a global state that is updated at each call and is, by default, seeded from the nondeterministic state of the OS. This is problematic for reproducibility and for parallel execution

JAX relies on an explicitly-set random state called a key:

from jax import random
key = random.key(18)

Each key should only be used once (in a single call to a random function), but it can be split into new keys:

key, subkey = random.split(key)

The first key can’t be used anymore. We overwrote it with a new key to ensure we don’t accidentally reuse it

We can now use subkey in random functions in our code (and keep key to generate new subkeys as needed)
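A minimal sketch of this workflow follows, using `random.key`, `random.split`, and `random.normal` from `jax.random`. The variable names and array shapes are illustrative. It also shows why a key should not be reused: the same key always produces the same numbers.

```python
from jax import random

key = random.key(18)

# Split before each use: keep `key` for future splits, consume `subkey` now
key, subkey = random.split(key)
x = random.normal(subkey, shape=(3,))

# Reusing the same subkey reproduces exactly the same numbers
x_again = random.normal(subkey, shape=(3,))

# A fresh subkey gives new numbers
key, subkey = random.split(key)
y = random.normal(subkey, shape=(3,))

# Several subkeys can be created at once
subkeys = random.split(key, num=3)

print(x, x_again, y, sep="\n")
```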

JAX use cases

New JAX backends added to many PPLs

Edward2 and TensorFlow Probability can now use JAX as backend

PyMC relies on building a static graph. It is based on PyTensor, which provides JAX compilation (PyTensor is a fork of Aesara, itself a fork of Theano)

NumPyro

NumPyro is a library based on Pyro but using NumPy and JAX

Blackjax

Not a PPL but a library of MCMC samplers built on JAX

Can be used directly if you want to define your own log-probability density functions or can be used with several PPLs to define your model (make sure to translate it to a log-probability function)

Also provides building blocks for experimentation with new algorithms
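As a rough idea of what a hand-written log-probability density function can look like, here is a sketch using only JAX. The model (Normal data with unknown mean and scale), the priors, the data, and the name `logdensity_fn` are all assumptions made for illustration; passing such a function to a Blackjax sampler is not shown here.

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats

# Hypothetical observations, for illustration only
observed = jnp.array([1.8, 2.1, 2.4, 1.9, 2.3])

def logdensity_fn(position):
    """Unnormalized log-posterior of a Normal model with unknown mean and scale."""
    mu = position["mu"]
    log_sigma = position["log_sigma"]
    sigma = jnp.exp(log_sigma)  # working with log(sigma) keeps sigma positive
    # Priors (assumed): Normal(0, 10) on mu, Normal(0, 1) on log(sigma)
    log_prior = (stats.norm.logpdf(mu, loc=0.0, scale=10.0)
                 + stats.norm.logpdf(log_sigma, loc=0.0, scale=1.0))
    # Likelihood: independent Normal(mu, sigma) observations
    log_lik = jnp.sum(stats.norm.logpdf(observed, loc=mu, scale=sigma))
    return log_prior + log_lik

# Gradient-based samplers such as HMC and NUTS need both the value and the gradient
position = {"mu": 0.0, "log_sigma": 0.0}
value, grads = jax.value_and_grad(logdensity_fn)(position)
print(value, grads)
```

The position is a dictionary, and JAX returns the gradient with the same structure, which is convenient for models with many named parameters.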

Blackjax

Example Blackjax sampler: HMC

Example Blackjax sampler: NUTS

Which tool to choose?

All these tools are in active development (JAX was released in 2018 and started shaking up the field). Things are evolving fast. Reading the blogs of the main developers, posts on Hacker News, discourse forums, etc. helps to keep an eye on developments in the field

This recent conversation between Bob Carpenter (Stan core developer) and Ricardo Vieira (PyMC core developer) in the PyMC discourse forum is interesting

A lot of it also comes down to user preferences

Resources

How to get started with Bayesian computing?

The book Probabilistic Programming & Bayesian Methods for Hackers by Cameron Davidson-Pilon provides a code-based (using PyMC) and math-free introduction to Bayesian methods for the real beginner

Several resources on the PyMC website including intro Bayesian with PyMC

NumPyro tutorials

More advanced: tutorials from the Blackjax Sampling Book Project

How to transition from Stan to a JAX-based PPL?

The code from the classic Bayesian textbook Statistical Rethinking by Richard McElreath has been translated by various people to modern JAX-based PPLs: